{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "93caa9f6",
   "metadata": {},
   "source": [
    "## Basic Inference (teacher-guided, 8 steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72e5ced3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from infer import DMOInference\n",
    "import IPython.display as ipd\n",
    "import torchaudio\n",
    "import time\n",
    "\n",
    "# Checkpoint files; relative paths resolve against the kernel's working directory.\n",
    "STUDENT_CKPT = \"../ckpts/model_85000.pt\"\n",
    "DURATION_CKPT = \"../ckpts/model_1500.pt\"\n",
    "\n",
    "# Build the inference object once and keep it in `tts` for the generation cells.\n",
    "tts = DMOInference(\n",
    "    student_checkpoint_path=STUDENT_CKPT,\n",
    "    duration_predictor_path=DURATION_CKPT,\n",
    "    device=\"cuda\",\n",
    "    model_type=\"F5TTS_Base\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc5b2634",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_audio = \"f5_tts/infer/examples/basic/basic_ref_en.wav\"\n",
    "\n",
    "ref_text = \"Some call me nature, others call me mother nature.\"\n",
    "gen_text = \"I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring.\"\n",
    "\n",
    "# Sample rate (Hz) this notebook assumes for the generated audio; used both to\n",
    "# convert the sample count into seconds and for playback.\n",
    "SAMPLE_RATE = 24000\n",
    "\n",
    "# perf_counter() is a monotonic, high-resolution clock meant for measuring\n",
    "# intervals; time.time() can jump if the system clock is adjusted.\n",
    "start_time = time.perf_counter()\n",
    "# Generate with default settings\n",
    "generated_audio = tts.generate(\n",
    "    gen_text=gen_text,\n",
    "    audio_path=prompt_audio,\n",
    "    prompt_text=ref_text,\n",
    ")\n",
    "end_time = time.perf_counter()\n",
    "\n",
    "# Real-time factor: seconds of compute per second of generated audio.\n",
    "processing_time = end_time - start_time\n",
    "audio_duration = generated_audio.shape[-1] / SAMPLE_RATE\n",
    "rtf = processing_time / audio_duration\n",
    "\n",
    "print('\\n--------\\n')\n",
    "print('Prompt Audio: ')\n",
    "display(ipd.Audio(prompt_audio, rate=SAMPLE_RATE))\n",
    "print('Generated Audio: ')\n",
    "display(ipd.Audio(generated_audio, rate=SAMPLE_RATE))\n",
    "\n",
    "print(f\"  RTF: {rtf:.2f}x ({1/rtf:.2f}x speed)\")\n",
    "print(f\"  Processing: {processing_time:.2f}s for {audio_duration:.2f}s audio\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d65b4bde",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_audio = \"f5_tts/infer/examples/basic/basic_ref_zh.wav\"\n",
    "\n",
    "ref_text = \"对，这就是我，万人敬仰的太乙真人。\"\n",
    "gen_text = '突然，身边一阵笑声。我看着他们，意气风发地挺直了胸膛，甩了甩那稍显肉感的双臂，轻笑道：\"我身上的肉，是为了掩饰我爆棚的魅力，否则，岂不吓坏了你们呢？\"'\n",
    "\n",
    "# Sample rate (Hz) this notebook assumes for the generated audio; used both to\n",
    "# convert the sample count into seconds and for playback.\n",
    "SAMPLE_RATE = 24000\n",
    "\n",
    "# perf_counter() is a monotonic, high-resolution clock meant for measuring\n",
    "# intervals; time.time() can jump if the system clock is adjusted.\n",
    "start_time = time.perf_counter()\n",
    "# Generate with default settings\n",
    "generated_audio = tts.generate(\n",
    "    gen_text=gen_text,\n",
    "    audio_path=prompt_audio,\n",
    "    prompt_text=ref_text,\n",
    ")\n",
    "end_time = time.perf_counter()\n",
    "\n",
    "# Real-time factor: seconds of compute per second of generated audio.\n",
    "processing_time = end_time - start_time\n",
    "audio_duration = generated_audio.shape[-1] / SAMPLE_RATE\n",
    "rtf = processing_time / audio_duration\n",
    "\n",
    "print('\\n--------\\n')\n",
    "print('Prompt Audio: ')\n",
    "display(ipd.Audio(prompt_audio, rate=SAMPLE_RATE))\n",
    "print('Generated Audio: ')\n",
    "display(ipd.Audio(generated_audio, rate=SAMPLE_RATE))\n",
    "\n",
    "print(f\"  RTF: {rtf:.2f}x ({1/rtf:.2f}x speed)\")\n",
    "print(f\"  Processing: {processing_time:.2f}s for {audio_duration:.2f}s audio\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0b8df0f",
   "metadata": {},
   "source": [
    "## Comparison between different sampling configurations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3962e78c",
   "metadata": {},
   "source": [
    "#### Student only (4 steps)\n",
    "\n",
    "Set both `teacher_steps` and `student_start_step` to 0 to enable full student sampling (no teacher steps)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53a232a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_audio = \"f5_tts/infer/examples/basic/basic_ref_zh.wav\"\n",
    "\n",
    "ref_text = \"对，这就是我，万人敬仰的太乙真人。\"\n",
    "gen_text = '突然，身边一阵笑声。我看着他们，意气风发地挺直了胸膛，甩了甩那稍显肉感的双臂，轻笑道：\"我身上的肉，是为了掩饰我爆棚的魅力，否则，岂不吓坏了你们呢？\"'\n",
    "\n",
    "# Sample rate (Hz) this notebook assumes for the generated audio; used both to\n",
    "# convert the sample count into seconds and for playback.\n",
    "SAMPLE_RATE = 24000\n",
    "\n",
    "# perf_counter() is a monotonic, high-resolution clock meant for measuring\n",
    "# intervals; time.time() can jump if the system clock is adjusted.\n",
    "start_time = time.perf_counter()\n",
    "# Generate with the student only (both teacher-related settings are zeroed)\n",
    "generated_audio = tts.generate(\n",
    "    gen_text=gen_text,\n",
    "    audio_path=prompt_audio,\n",
    "    prompt_text=ref_text,\n",
    "    teacher_steps=0,  # set this to 0 for no teacher sampling\n",
    "    student_start_step=0,  # set this to 0 for full student sampling\n",
    ")\n",
    "end_time = time.perf_counter()\n",
    "\n",
    "# Real-time factor: seconds of compute per second of generated audio.\n",
    "processing_time = end_time - start_time\n",
    "audio_duration = generated_audio.shape[-1] / SAMPLE_RATE\n",
    "rtf = processing_time / audio_duration\n",
    "\n",
    "print('\\n--------\\n')\n",
    "print('Prompt Audio: ')\n",
    "display(ipd.Audio(prompt_audio, rate=SAMPLE_RATE))\n",
    "print('Generated Audio: ')\n",
    "display(ipd.Audio(generated_audio, rate=SAMPLE_RATE))\n",
    "\n",
    "print(f\"  RTF: {rtf:.2f}x ({1/rtf:.2f}x speed)\")\n",
    "print(f\"  Processing: {processing_time:.2f}s for {audio_duration:.2f}s audio\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5cb24808",
   "metadata": {},
   "source": [
    "#### More teacher steps (16 steps)\n",
    "\n",
    "Now we use 14 steps from the teacher and 2 steps from the student to have higher diversity (16 steps total)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10b3620e",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "prompt_audio = \"f5_tts/infer/examples/basic/basic_ref_zh.wav\"\n",
    "\n",
    "ref_text = \"对，这就是我，万人敬仰的太乙真人。\"\n",
    "gen_text = '突然，身边一阵笑声。我看着他们，意气风发地挺直了胸膛，甩了甩那稍显肉感的双臂，轻笑道：\"我身上的肉，是为了掩饰我爆棚的魅力，否则，岂不吓坏了你们呢？\"'\n",
    "\n",
    "# Sample rate (Hz) this notebook assumes for the generated audio; used both to\n",
    "# convert the sample count into seconds and for playback.\n",
    "SAMPLE_RATE = 24000\n",
    "\n",
    "# perf_counter() is a monotonic, high-resolution clock meant for measuring\n",
    "# intervals; time.time() can jump if the system clock is adjusted.\n",
    "start_time = time.perf_counter()\n",
    "# Generate with more teacher steps and only two student steps\n",
    "generated_audio = tts.generate(\n",
    "    gen_text=gen_text,\n",
    "    audio_path=prompt_audio,\n",
    "    prompt_text=ref_text,\n",
    "    teacher_steps=24,\n",
    "    # NOTE(review): 0.3 is passed here although the reference point quoted in the\n",
    "    # comment below is 0.25 -- confirm which value is intended.\n",
    "    teacher_stopping_time=0.3,  # for reference, 0.25 means students go for the last two steps (0.26ish, 0.6ish)\n",
    "    student_start_step=2,  # only two steps for students\n",
    "    verbose=True,  # see the number of steps used\n",
    ")\n",
    "end_time = time.perf_counter()\n",
    "\n",
    "# Real-time factor: seconds of compute per second of generated audio.\n",
    "processing_time = end_time - start_time\n",
    "audio_duration = generated_audio.shape[-1] / SAMPLE_RATE\n",
    "rtf = processing_time / audio_duration\n",
    "\n",
    "print('\\n--------\\n')\n",
    "print('Prompt Audio: ')\n",
    "display(ipd.Audio(prompt_audio, rate=SAMPLE_RATE))\n",
    "print('Generated Audio: ')\n",
    "display(ipd.Audio(generated_audio, rate=SAMPLE_RATE))\n",
    "\n",
    "print(f\"  RTF: {rtf:.2f}x ({1/rtf:.2f}x speed)\")\n",
    "print(f\"  Processing: {processing_time:.2f}s for {audio_duration:.2f}s audio\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70fd3cb1",
   "metadata": {},
   "source": [
    "#### Stochastic duration\n",
    "\n",
    "Introduce even more diversity by adding randomness to the duration: pass a non-zero `temperature` so the duration is sampled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb39d323",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_audio = \"f5_tts/infer/examples/basic/basic_ref_zh.wav\"\n",
    "\n",
    "ref_text = \"对，这就是我，万人敬仰的太乙真人。\"\n",
    "gen_text = '突然，身边一阵笑声。我看着他们，意气风发地挺直了胸膛，甩了甩那稍显肉感的双臂，轻笑道：\"我身上的肉，是为了掩饰我爆棚的魅力，否则，岂不吓坏了你们呢？\"'\n",
    "\n",
    "# Sample rate (Hz) this notebook assumes for the generated audio; used both to\n",
    "# convert the sample count into seconds and for playback.\n",
    "SAMPLE_RATE = 24000\n",
    "\n",
    "# perf_counter() is a monotonic, high-resolution clock meant for measuring\n",
    "# intervals; time.time() can jump if the system clock is adjusted.\n",
    "start_time = time.perf_counter()\n",
    "# Generate with teacher guidance plus a non-zero duration temperature\n",
    "generated_audio = tts.generate(\n",
    "    gen_text=gen_text,\n",
    "    audio_path=prompt_audio,\n",
    "    prompt_text=ref_text,\n",
    "    teacher_steps=24,\n",
    "    teacher_stopping_time=0.25,  # 0.25 means students go for the last two steps (0.26ish, 0.6ish)\n",
    "    student_start_step=2,  # only two steps for students\n",
    "    temperature=0.8,  # set some temperature for duration sampling\n",
    ")\n",
    "end_time = time.perf_counter()\n",
    "\n",
    "# Real-time factor: seconds of compute per second of generated audio.\n",
    "processing_time = end_time - start_time\n",
    "audio_duration = generated_audio.shape[-1] / SAMPLE_RATE\n",
    "rtf = processing_time / audio_duration\n",
    "\n",
    "print('\\n--------\\n')\n",
    "print('Prompt Audio: ')\n",
    "display(ipd.Audio(prompt_audio, rate=SAMPLE_RATE))\n",
    "print('Generated Audio: ')\n",
    "display(ipd.Audio(generated_audio, rate=SAMPLE_RATE))\n",
    "\n",
    "print(f\"  RTF: {rtf:.2f}x ({1/rtf:.2f}x speed)\")\n",
    "print(f\"  Processing: {processing_time:.2f}s for {audio_duration:.2f}s audio\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f109ed3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cc13360",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
